{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['finetune-t5-super-tiny-standard-bahasa-cased/checkpoint-1730000',\n",
       " 'finetune-t5-super-tiny-standard-bahasa-cased/checkpoint-1740000',\n",
       " 'finetune-t5-super-tiny-standard-bahasa-cased/checkpoint-1750000']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "checkpoints = sorted(glob('finetune-t5-super-tiny-standard-bahasa-cased/checkpoint-*'))\n",
    "checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "tokenizer = T5Tokenizer.from_pretrained('mesolitica/t5-small-standard-bahasa-cased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = T5ForConditionalGeneration.from_pretrained(checkpoints[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "string1 = 'huseinsukamakan ayam,dia sgtrisaukan'\n",
    "string2 = 'drmahathir sangat menekankan budaya budakzamansekarang'\n",
    "string3 = 'ceritatunnajibrazak'\n",
    "string4 = 'TunM sukakan'\n",
    "string_hard = 'IPOH-AhliDewanUndangan Negeri(ADUN) HuluKinta, MuhamadArafat Varisai Mahamadmenafikanmesejtularmendakwa beliau akan melompatparti menyokong UMNO membentuk kerajaannegeridiPerak.BeliauyangjugaKetua Penerangan Parti Keadilan Rakyat(PKR)Perak dalam satumesejringkaskepadaSinar Harian menjelaskan perkara itutidakbenarsama sekali.'\n",
    "string_socialmedia = 'aqxsukalah apeyg tejadidekat mamattu'\n",
    "string5 = 'ihate chicken, but ilike fish'\n",
    "string6 = 'Higuys! I noticedsemalam & harini dahramai yangdapat cookiesni kan. So hariniinak sharesome post mortemof our first batch:'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "strings = [\n",
    "    string1,\n",
    "    string2,\n",
    "    string3,\n",
    "    string4,\n",
    "    string_hard,\n",
    "    string_socialmedia,\n",
    "    string5,\n",
    "    string6\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['husein suka makan ayam, dia sgt rasakan',\n",
       " 'dr mahathir sangat menekankan budaya budak zaman sekarang',\n",
       " 'cerita tun najib razak',\n",
       " 'Tun M sukakan',\n",
       " 'IPOH - Ahli Dewan Undangan Negeri (ADUN) Hulu Kinta, Muhamad Arafat Varisai Mahamad menafikan mesej tular mendakwa beliau akan melompat parti menyokong UMNO membentuk kerajaan negeri di Perak. Beliau yang juga Ketua Penerangan Parti Keadilan Rakyat (PKR) Perak dalam satu mesej ringkas kepada Sinar Harian menjelaskan perkara itu tidak benar sama sekali.',\n",
       " 'aq x sukalah ape yg tejadi dekat mamat tu',\n",
       " 'i hate chicken, but i like fish',\n",
       " 'Hi guys! I noticed semalam & harini dah ramai yang dapat cookies ni kan. So harini i nak share some post mortem of our first batch:']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids = [{'input_ids': tokenizer.encode(f'segmentasi: {s}', return_tensors='pt')[\n",
    "    0]} for s in strings]\n",
    "padded = tokenizer.pad(input_ids, padding='longest')\n",
    "outputs = model.generate(**padded, max_length=256)\n",
    "tokenizer.batch_decode(outputs, skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d96a3ca7d9074bedb90339f2fc7dd824",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file pytorch_model.bin:   0%|          | 32.0k/48.4M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "remote: Scanning LFS files for validity, may be slow...        \n",
      "remote: LFS file scan complete.        \n",
      "To https://huggingface.co/mesolitica/finetune-segmentation-t5-super-tiny-standard-bahasa-cased\n",
      "   f52c418..aab9233  main -> main\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/mesolitica/finetune-segmentation-t5-super-tiny-standard-bahasa-cased/commit/aab923345df76369cbc9e90301e17fc77e9eb592'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.push_to_hub('finetune-segmentation-t5-super-tiny-standard-bahasa-cased', organization='mesolitica')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer.push_to_hub('finetune-segmentation-t5-super-tiny-standard-bahasa-cased', organization='mesolitica')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('test-set-segmentation.json') as fopen:\n",
    "    data = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_cer(actual, hyp):\n",
    "    \"\"\"\n",
    "    Calculate CER using `python-Levenshtein`.\n",
    "    \"\"\"\n",
    "    import Levenshtein as Lev\n",
    "\n",
    "    actual = actual.replace(' ', '')\n",
    "    hyp = hyp.replace(' ', '')\n",
    "    return Lev.distance(actual, hyp) / len(actual)\n",
    "\n",
    "def calculate_wer(actual, hyp):\n",
    "    \"\"\"\n",
    "    Calculate WER using `python-Levenshtein`.\n",
    "    \"\"\"\n",
    "    import Levenshtein as Lev\n",
    "\n",
    "    b = set(actual.split() + hyp.split())\n",
    "    word2char = dict(zip(b, range(len(b))))\n",
    "\n",
    "    w1 = [chr(word2char[w]) for w in actual.split()]\n",
    "    w2 = [chr(word2char[w]) for w in hyp.split()]\n",
    "\n",
    "    return Lev.distance(''.join(w1), ''.join(w2)) / len(actual.split())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [1:11:00<00:00,  2.35it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "wer, cer = [], []\n",
    "for i in tqdm(range(len(data[:10000]))):\n",
    "    input_ids = [{'input_ids': tokenizer.encode(f'segmentasi: {data[i][0]}', return_tensors='pt')[0]}]\n",
    "    padded = tokenizer.pad(input_ids, padding='longest')\n",
    "    \n",
    "#     for k in padded.keys():\n",
    "#         padded[k] = padded[k].cuda()\n",
    "    \n",
    "    outputs = model.generate(**padded, max_length=256)\n",
    "    predicted = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]\n",
    "    actual = data[i][1]\n",
    "    wer.append(calculate_wer(actual, predicted))\n",
    "    cer.append(calculate_cer(actual, predicted))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.030962535411709086, 0.004112925345819103)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.mean(wer), np.mean(cer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
